import torch
import torch.nn as nn

# Demo: apply one dropout layer (p=0.5) to an all-ones input three times.
x = torch.ones(1, 5)
print(x)

# A freshly constructed nn.Module is in training mode, so dropout is active:
# each element is zeroed with probability p and survivors are scaled by
# 1 / (1 - p) = 2.0.
drop = nn.Dropout(p=0.5)

# Keep `x` fixed and store each pass in `out`. The previous `x = drop(x)`
# fed every output back in as the next input. That compounded the passes:
# survivors were rescaled again on every call (1 -> 2 -> 4 -> 8) and any
# position zeroed once stayed zero. The prints then did not show what a
# single dropout pass does.
out = drop(x)
print(out)

out = drop(x)
print(out)

out = drop(x)
print(out)

# Second demo: a fresh dropout layer applied three times to one fixed input.
# The input is never overwritten, so each pass draws an independent mask over
# the same all-ones tensor. Dropped entries become 0 and kept entries are
# scaled by 1 / (1 - 0.5) = 2.0 (training mode is the default for a new module).
drop = nn.Dropout(0.5)
x = torch.ones((1, 5))

for i in range(3):
    print(out := drop(x))
